# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# End-to-end convolution 2d tests.

load("//build_tools/bazel:iree_e2e_generated_runner_test.bzl", "iree_generated_e2e_runner_test")

# Package-level settings applied to every target declared in this BUILD file:
# turns on the `layering_check` feature and declares the license type.
package(
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)

# Python generator script. The test suites in this file reference it through
# their `generator` attribute and pass it the element types and shape set via
# `generator_args`.
py_binary(
    name = "generate_e2e_conv2d_tests",
    srcs = ["generate_e2e_conv2d_tests.py"],
)

###########################################################################
##
## LLVMCPU backend
##
###########################################################################

# Default CPU backend.
#
# Declares one generated conv2d test per (element type, shape set) pair. The
# same element type is passed to the generator as the input, kernel, and
# accumulator type, which is why it appears three times in the target name.
[iree_generated_e2e_runner_test(
    name = "e2e_conv2d_cpu_{0}_{0}_{0}_{1}".format(elem_type, shape_set),
    generator = ":generate_e2e_conv2d_tests",
    generator_args = [
        "--input_type=" + elem_type,
        "--kernel_type=" + elem_type,
        "--acc_type=" + elem_type,
        "--shapes=" + shape_set,
    ],
    tags = ["hostonly", "local"],
    target_backends_and_drivers = [("llvm-cpu", "local-task")],
    test_runner = "//tools/testing/e2e:iree-e2e-conv2d-test",
    test_type = "conv2d",
) for elem_type in ["f32", "f16"] for shape_set in ["small", "medium", "large"]]

# Default CPU backend + winograd.
#
# Same (element type, shape set) matrix as the plain CPU suite, but each test
# is additionally compiled with a preprocessing pass pipeline that runs
# `iree-linalg-ext-convert-conv2d-to-winograd` with `replace-all-convs=true`.
[iree_generated_e2e_runner_test(
    name = "e2e_winograd_conv2d_cpu_{0}_{0}_{0}_{1}".format(elem_type, shape_set),
    compiler_flags = [
        # The braces in this flag are literal pass options, so the string is
        # deliberately not run through any formatting.
        "--iree-preprocessing-pass-pipeline=builtin.module\\(func.func\\(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true}\\)\\)",
    ],
    generator = ":generate_e2e_conv2d_tests",
    generator_args = [
        "--input_type=" + elem_type,
        "--kernel_type=" + elem_type,
        "--acc_type=" + elem_type,
        "--shapes=" + shape_set,
    ],
    tags = ["hostonly", "local"],
    target_backends_and_drivers = [("llvm-cpu", "local-task")],
    test_runner = "//tools/testing/e2e:iree-e2e-conv2d-test",
    test_type = "conv2d",
) for elem_type in ["f32", "f16"] for shape_set in ["small", "medium", "large"]]
